# main.py — TP "Marker Tracking" (OpenCV)
#
# Objectif (résumé) :
# - Charger une image de référence (marker.jpg)
# - Ouvrir un flux vidéo (webcam ou fichier)
# - Détecter + décrire le marker (une fois)
# - Pour chaque frame : détecter + décrire, matcher, filtrer, homographie (RANSAC)
# - Projeter les 4 coins du marker, dessiner + masquer en blanc si détection valide
#
# IMPORTANT : ce TP fait de la "redétection + matching" à chaque frame (pas de tracking temporel).

import cv2 as cv
import numpy as np

import MyFeatureDetector as featureDetect
import MyDescriptorExtractor as descExtract
import MyDescriptorMatcher as descMatch

# -----------------------------------------------------------------------------#
# CONFIG
# -----------------------------------------------------------------------------#
MARKER_PATH = "marker.jpg"
VIDEO_SRC = 0  # 0 webcam, ou "video.mp4"

EPSILON_RANSAC = 3.0      # epsilon pour cv.findHomography(..., RANSAC, epsilon)
MIN_INLIERS = 15          # seuil d'inliers pour valider la détection
ESC_KEY = 27              # ESC

# Choix algo (optionnel Q11)
# Pour le sujet de base : ORB
DETECTOR_TYPE = cv.ORB
DESCRIPTOR_TYPE = cv.ORB

# -----------------------------------------------------------------------------#
# Q1 — Lecture du code / structure générale
# -> Ici, le main orchestre le pipeline ; les classes My* encapsulent OpenCV.
# -----------------------------------------------------------------------------#

# -----------------------------------------------------------------------------#
# Étape 1 — Charger le marker
# -----------------------------------------------------------------------------#
marker_bgr = cv.imread(MARKER_PATH, cv.IMREAD_COLOR)
if marker_bgr is None:
    raise RuntimeError(f"Cannot read marker image: {MARKER_PATH}")

marker_gray = cv.cvtColor(marker_bgr, cv.COLOR_BGR2GRAY)

# Coins du marker (pour la projection)
hM, wM = marker_gray.shape[:2]
marker_corners = np.float32([[0, 0], [wM, 0], [wM, hM], [0, hM]]).reshape(-1, 1, 2)

# -----------------------------------------------------------------------------#
# Étape 2 — Détecter + décrire le marker (FP_M, FD_M)
# -----------------------------------------------------------------------------#

# Q2 — Créer un détecteur ORB pour le marker (dans MyFeatureDetector.py)
det_marker = featureDetect.MyFeatureDetector(DETECTOR_TYPE)

# TODO Q3.a : détecter les points d'intérêt du marker
# 1) det_marker.setImage(marker_gray)
# 2) det_marker.detectFeatures()
# 3) kp_marker = det_marker.getFeatures()

# TODO Q3.b : afficher les features détectées sur le marker (utiliser displayFeatures)
# Indication : votre classe MyFeatureDetector doit fournir displayFeatures() ou équivalent.
# ex: marker_kp_vis = det_marker.displayFeatures(marker_bgr)
# cv.imshow("Marker - Features", marker_kp_vis)

# Q4 — Créer un extracteur ORB (dans MyDescriptorExtractor.py)
ext_marker = descExtract.MyDescriptorExtractor(DESCRIPTOR_TYPE)

# TODO Q5.a : calculer les descripteurs du marker (FD_M)
# 1) ext_marker.setImage(marker_gray)
# 2) ext_marker.setFeatures(kp_marker)
# 3) ext_marker.computeDescriptors()
# 4) desc_marker = ext_marker.getDescriptors()

# -----------------------------------------------------------------------------#
# Étape 3 — Ouvrir le flux vidéo
# -----------------------------------------------------------------------------#
cap = cv.VideoCapture(VIDEO_SRC)
if not cap.isOpened():
    raise RuntimeError(f"Cannot open video source: {VIDEO_SRC}")

# -----------------------------------------------------------------------------#
# Étape 4 — Matcher (mise en correspondance)
# -----------------------------------------------------------------------------#
matcher = descMatch.MyDescriptorMatcher()

# -----------------------------------------------------------------------------#
# Boucle principale — pour chaque frame I_t
# -----------------------------------------------------------------------------#
while True:
    ok, frame_bgr = cap.read()
    if not ok:
        break

    frame_gray = cv.cvtColor(frame_bgr, cv.COLOR_BGR2GRAY)

    # -------------------------------------------------------------------------#
    # (1) Détecter des features dans la frame (FP_t)
    # -------------------------------------------------------------------------#

    det_frame = featureDetect.MyFeatureDetector(DETECTOR_TYPE)

    # TODO Q3.c : détecter les points d'intérêt de la frame
    # 1) det_frame.setImage(frame_gray)
    # 2) det_frame.detectFeatures()
    # 3) kp_frame = det_frame.getFeatures()

    # TODO Q3.d : afficher les features détectées sur la frame (option debug)
    # ex: frame_kp_vis = det_frame.displayFeatures(frame_bgr)
    # cv.imshow("Frame - Features", frame_kp_vis)

    # -------------------------------------------------------------------------#
    # (2) Décrire ces features (FD_t)
    # -------------------------------------------------------------------------#

    ext_frame = descExtract.MyDescriptorExtractor(DESCRIPTOR_TYPE)

    # TODO Q5.b : calculer les descripteurs de la frame (FD_t)
    # 1) ext_frame.setImage(frame_gray)
    # 2) ext_frame.setFeatures(kp_frame)
    # 3) ext_frame.computeDescriptors()
    # 4) desc_frame = ext_frame.getDescriptors()

    # -------------------------------------------------------------------------#
    # (3) Matching FD_M <-> FD_t + filtrage des "bons matchs"
    # -------------------------------------------------------------------------#

    found = False
    inlier_count = 0

    # TODO : ajouter une condition robuste pour éviter les None / listes vides
    # if desc_marker is not None and desc_frame is not None and len(kp_marker)>0 and len(kp_frame)>0:
    #     ...

    # TODO Q6 : réaliser le matching et récupérer bestMatches
    # bestMatches = matcher.match(desc_marker, desc_frame)
    bestMatches = []

    # TODO Q7 : afficher le résultat du matching (option)
    # matching_vis = matcher.drawMatchingResults(marker_bgr, kp_marker, frame_bgr, kp_frame, bestMatches)
    # cv.imshow("Matching", matching_vis)

    # -------------------------------------------------------------------------#
    # (4) Estimer l'homographie H_{M->t} avec RANSAC
    # -------------------------------------------------------------------------#

    H = None
    inliers = None

    # TODO Q8.a : si len(bestMatches) >= 4 :
    # 1) construire pts_marker (dans repère marker) et pts_frame (dans repère frame)
    #    pts_marker = np.float32([kp_marker[m.queryIdx].pt for m in bestMatches]).reshape(-1,1,2)
    #    pts_frame  = np.float32([kp_frame[m.trainIdx].pt  for m in bestMatches]).reshape(-1,1,2)
    # 2) appeler cv.findHomography(pts_marker, pts_frame, cv.RANSAC, EPSILON_RANSAC)
    #    H, inliers = cv.findHomography(...)

    # TODO Q8.b : calculer inlier_count
    # inlier_count = int(inliers.sum()) if inliers is not None else 0

    # -------------------------------------------------------------------------#
    # (5) Augmentation : projeter les coins + masquer
    # -------------------------------------------------------------------------#

    # TODO Q9.a : condition de validité :
    # if H is not None and inlier_count >= MIN_INLIERS:
    #     proj = cv.perspectiveTransform(marker_corners, H)
    #     cv.polylines(frame_bgr, [proj.astype(np.int32)], True, (0,255,0), 2)  # debug
    #     cv.fillConvexPoly(frame_bgr, proj.astype(np.int32), (255,255,255))
    #     found = True

    # -------------------------------------------------------------------------#
    # (6) Affichage HUD + fenêtre
    # -------------------------------------------------------------------------#
    if found:
        cv.putText(frame_bgr, f"FOUND (inliers={inlier_count})", (20, 30),
                   cv.FONT_HERSHEY_SIMPLEX, 1.0, (0, 255, 0), 2)
    else:
        cv.putText(frame_bgr, "NOT FOUND", (20, 30),
                   cv.FONT_HERSHEY_SIMPLEX, 1.0, (0, 0, 255), 2)

    cv.imshow("Tracking", frame_bgr)

    if cv.waitKey(1) == ESC_KEY:
        break

cap.release()
cv.destroyAllWindows()
